{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Saving and loading models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from chemprop.models.utils import save_model, load_model\n",
    "from chemprop.models.model import MPNN\n",
    "from chemprop.models.multi import MulticomponentMPNN\n",
    "from chemprop import nn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is an example buffer to save to and load from, to avoid creating new files when running this notebook. A real use case would probably save to and read from a file like `model.pt`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import io\n",
    "\n",
    "saved_model = io.BytesIO()\n",
    "\n",
    "# from pathlib import Path\n",
    "# saved_model = Path(\"model.pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Saving models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A valid model save file is a dictionary containing the hyper parameters and state dict of the model. `torch` is used to pickle the dictionary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = MPNN(nn.BondMessagePassing(), nn.MeanAggregation(), nn.RegressionFFN())\n",
    "\n",
    "save_model(saved_model, model)\n",
    "\n",
    "# model_dict = {\"hyper_parameters\": model.hparams, \"state_dict\": model.state_dict()}\n",
    "# torch.save(model_dict, saved_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`lightning` will also automatically create checkpoint files during training. These `.ckpt` files are like `.pt` model files, but also contain information about training and can be used to restart training. See the `lightning` documentation for more details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (mps), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    }
   ],
   "source": [
    "from lightning.pytorch.callbacks import ModelCheckpoint\n",
    "from lightning.pytorch import Trainer\n",
    "\n",
    "checkpointing = ModelCheckpoint(\n",
    "    dirpath=\"mycheckpoints\",\n",
    "    filename=\"best-{epoch}-{val_loss:.2f}\",\n",
    "    monitor=\"val_loss\",\n",
    "    mode=\"min\",\n",
    "    save_last=True,\n",
    ")\n",
    "trainer = Trainer(callbacks=[checkpointing])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Loading models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`MPNN` and `MulticomponentMPNN` each have a class method to load a model from either a model file `.pt` or a checkpoint file `.ckpt`. The method to load from a file works for either model files or checkpoint files, but won't load the saved training information from a checkpoint file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Need to set the buffer stream position to the beginning, not necessary if using a file\n",
    "saved_model.seek(0)\n",
    "\n",
    "model = MPNN.load_from_file(saved_model)\n",
    "\n",
    "# Other options\n",
    "# model = MPNN.load_from_checkpoint(saved_model)\n",
    "# model = MulticomponentMPNN.load_from_file(saved_model)\n",
    "# model = MulticomponentMPNN.load_from_checkpoint(saved_model)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
